import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

class MomentumQueue:
    """Fixed-capacity FIFO of c-dimensional vectors, each tagged with a timestamp.

    Rows are appended until the queue holds ``k`` of them; after that every new row
    overwrites the oldest one (ring buffer; ``self.pt`` is the next slot to overwrite).
    Besides the caller-supplied timestamp (``self.times``) each row records its global
    push index (``self.queue_times``), which :meth:`query` uses to prefer more recently
    pushed rows when timestamp distances tie.
    """

    def __init__(self, k, c):
        """
        Initialize an empty queue holding at most ``k`` vectors of dimension ``c``.

        Args:
            k (int): The maximum length of the queue.
            c (int): The dimension of each vector.
        """
        self.k = k
        self.c = c
        self.len = 0  # number of rows currently stored (<= k)
        self.pt = 0  # ring-buffer slot overwritten next once the queue is full
        self.push_time = 0  # total number of rows ever pushed
        # Scale of the recency tie-break in query(). Push indices are re-based to the
        # oldest stored row there, so the tie-break stays below (k - 1) / (2k) < 1 and
        # never outweighs a one-unit timestamp difference.
        self.divisor = 2 * self.k
        self.queue = torch.randn(0, c)  # empty [0, c] tensor; rows are appended by push()
        self.times = torch.empty(0, dtype=torch.float32)  # timestamp of each stored row
        self.queue_times = torch.empty(0, dtype=torch.float32)  # push index of each stored row

    def get_all(self):
        """
        Return the stored elements.

        Returns:
            torch.Tensor: A ``[len, c]`` tensor (``len <= k``) in storage order, which is
            not chronological once the ring buffer has wrapped around.
        """
        return self.queue

    def _push_indices(self, start, stop, device):
        """Global push indices ``push_time + start .. push_time + stop - 1`` on ``device``."""
        return torch.arange(self.push_time + start, self.push_time + stop, device=device)

    def push(self, new_elements, new_times):
        """
        Append ``new_elements`` with their timestamps, overwriting the oldest rows when full.

        Args:
            new_elements (torch.Tensor): Tensor of shape [m, c]. If m > k only the last
                k rows can survive, so the first m - k rows are dropped up front.
            new_times (torch.Tensor | sequence): One timestamp per row of ``new_elements``.
        """
        batch_size = new_elements.shape[0]
        device = new_elements.device
        # Keep all state on the device of the incoming batch.
        self.queue = self.queue.to(device)
        self.times = self.times.to(device)
        self.queue_times = self.queue_times.to(device)

        if not isinstance(new_times, torch.Tensor):
            new_times = torch.tensor(new_times, dtype=torch.long, device=device)
        else:
            new_times = new_times.to(device)

        # A batch longer than the queue would overflow the ring-buffer slices below.
        # Its first (m - k) rows would be overwritten within this same push anyway, so
        # drop them now but still count them in push_time.
        if batch_size > self.k:
            dropped = batch_size - self.k
            new_elements = new_elements[dropped:]
            new_times = new_times[dropped:]
            self.push_time += dropped
            batch_size = self.k

        if self.len + batch_size <= self.k:
            # Enough free space: append at the end.
            self.queue = torch.cat([self.queue, new_elements], dim=0)
            self.times = torch.cat([self.times, new_times], dim=0)
            self.queue_times = torch.cat(
                [self.queue_times, self._push_indices(0, batch_size, device)], dim=0
            )
            self.len += batch_size

        elif self.len == self.k:
            # Full: overwrite from pt onwards, wrapping to the front if necessary.
            end = self.pt + batch_size
            if end <= self.k:
                self.queue[self.pt:end] = new_elements
                self.times[self.pt:end] = new_times
                self.queue_times[self.pt:end] = self._push_indices(0, batch_size, device)
            else:
                tail_len = self.k - self.pt
                head_len = batch_size - tail_len

                self.queue[self.pt:self.k] = new_elements[:tail_len]
                self.queue[0:head_len] = new_elements[tail_len:]

                self.times[self.pt:self.k] = new_times[:tail_len]
                self.times[0:head_len] = new_times[tail_len:]

                self.queue_times[self.pt:self.k] = self._push_indices(0, tail_len, device)
                self.queue_times[0:head_len] = self._push_indices(tail_len, batch_size, device)

            self.pt = end % self.k

        else:
            # Partially full: fill the free slots, then overwrite the oldest rows.
            insert_len = self.k - self.len
            place_len = batch_size - insert_len

            self.queue = torch.cat([self.queue, new_elements[:insert_len, :]], dim=0)
            self.queue[self.pt:self.pt + place_len] = new_elements[insert_len:]

            self.times = torch.cat([self.times, new_times[:insert_len]], dim=0)
            self.times[self.pt:self.pt + place_len] = new_times[insert_len:]

            self.queue_times = torch.cat(
                [self.queue_times, self._push_indices(0, insert_len, device)], dim=0
            )
            self.queue_times[self.pt:self.pt + place_len] = self._push_indices(
                insert_len, batch_size, device
            )

            self.len = self.k
            self.pt = (self.pt + place_len) % self.k

        self.push_time += batch_size

    def query(self, t_query, num_key):
        """
        Return up to ``num_key`` stored elements whose timestamps are closest to ``t_query``.

        Ties in timestamp distance are broken in favour of the more recently pushed
        element. The tie-break term is < 0.5, so this assumes timestamps that differ
        do so by at least 1 (e.g. integer timestamps).

        Args:
            t_query (int | float | torch.Tensor): Target timestamp (scalar).
            num_key (int): Maximum number of elements to return.

        Returns:
            torch.Tensor: ``[h, c]`` tensor with ``h = min(num_key, len)``, ordered from
            closest to farthest; an empty ``[0, c]`` tensor when the queue is empty.
        """
        if self.len == 0:
            return torch.empty(0, self.c, device=self.queue.device)

        target = torch.as_tensor(t_query, device=self.times.device)
        time_diffs = torch.abs(self.times - target)
        # Stored push indices are consecutive; re-basing to the oldest keeps the
        # recency bonus in [0, 0.5) regardless of how many pushes have happened.
        recency = (self.queue_times - self.queue_times.min()) / self.divisor
        scores = time_diffs - recency

        # Stable sort keeps storage order for exact score ties.
        sorted_indices = torch.argsort(scores, stable=True)
        num_key = min(num_key, self.len)
        return self.queue[sorted_indices[:num_key]]

class ZEncoder(nn.Module):
    def __init__(self, input_len=256, input_d=384, hidden_dim=1024, input_o=384):
        """
        Compression module: [batch, input_len, input_d] -> [batch, input_len // 4, input_o].

        Each token is first projected from ``input_d`` to ``input_o`` features by an MLP;
        then every 4 consecutive tokens are concatenated and projected back to ``input_o``.

        Args:
            input_len (int): Nominal input sequence length (default 256); must be divisible by 4.
            input_d (int): Input feature dimension (default 384).
            hidden_dim (int): Hidden width of both MLPs (default 1024).
            input_o (int): Output feature dimension (default 384).

        Raises:
            ValueError: If ``input_len`` is not divisible by 4.
        """
        super().__init__()
        self.n = input_len
        self.input_o = input_o

        # forward() groups 4 consecutive tokens, so the length must be divisible by 4.
        if input_len % 4 != 0:
            raise ValueError(f"输入序列长度 {input_len} 必须能被4整除")

        # Stage 1: per-token feature projection (input_d -> input_o).
        self.feature_compression = nn.Sequential(
            nn.Linear(input_d, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.input_o),
            nn.ReLU()
        )

        # Stage 2: sequence compression (4 concatenated tokens -> one token of input_o).
        self.sequence_compression = nn.Sequential(
            nn.Linear(4 * self.input_o, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.input_o),
            nn.ReLU()
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights (ReLU gain) and zero biases for every Linear layer."""
        for layer in self.feature_compression:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, nonlinearity='relu')
                if layer.bias is not None:
                    init.zeros_(layer.bias)

        for layer in self.sequence_compression:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, nonlinearity='relu')
                if layer.bias is not None:
                    init.zeros_(layer.bias)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch, seq_len, input_d]; seq_len must be divisible by 4.

        Returns:
            Tensor of shape [batch, seq_len // 4, input_o].
        """
        # [batch, seq_len, input_d] -> [batch, seq_len, input_o]
        x = self.feature_compression(x)

        # Concatenate every 4 consecutive tokens:
        # [batch, seq_len, input_o] -> [batch, seq_len // 4, 4 * input_o]
        x = x.reshape(x.size(0), x.size(1) // 4, 4 * self.input_o)

        # [batch, seq_len // 4, 4 * input_o] -> [batch, seq_len // 4, input_o]
        return self.sequence_compression(x)

class DINOEncoder(nn.Module):
    """Convolutional DINO-style encoder that outputs a probability distribution per sample.

    Input is ``[batch, in_channels, H, W]`` (the shape comments below assume 32x32);
    output is ``[batch, feature_dim]`` softmax probabilities.

    Initialisation:
    - Conv2d: Kaiming normal (fan_out, ReLU gain; GELU is treated as ReLU-like), zero bias
    - Linear: Xavier uniform, zero bias
    - BatchNorm1d / BatchNorm2d: weight 1, bias 0
    """

    def __init__(self, in_channels=4, feature_dim=128):
        super().__init__()

        def downsample(c_in, c_out):
            # A stride-2 3x3 conv halves the spatial size.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]

        self.conv_net = nn.Sequential(
            *downsample(in_channels, 64),  # -> [batch, 64, 16, 16]
            *downsample(64, 128),          # -> [batch, 128, 8, 8]
            *downsample(128, 256),         # -> [batch, 256, 4, 4]
            nn.AdaptiveAvgPool2d(1),       # global average pool -> [batch, 256, 1, 1]
        )

        # Projection head: 256 -> 512 -> feature_dim
        self.projection = nn.Sequential(
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, feature_dim),
        )

        self.softmax = nn.Softmax(dim=1)

        self._init_weights()

    def _init_weights(self):
        """Apply the per-layer-type initialisation described in the class docstring."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                init.ones_(module.weight)
                init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Conv2d):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
            else:
                continue
            if module.bias is not None:
                init.zeros_(module.bias)

    def forward(self, x):
        """Conv features -> flatten -> projection -> softmax over the feature dimension."""
        pooled = self.conv_net(x)                      # [batch, 256, 1, 1]
        flat = pooled.view(pooled.shape[0], -1)        # [batch, 256]
        return self.softmax(self.projection(flat))     # [batch, feature_dim]

class DINOEncoder2(nn.Module):
    """Transformer encoder that pools a token sequence into a single output vector.

    Args:
        input_dim: Input feature dimension (default 384).
        embed_dim: Model/hidden dimension of the Transformer (default 384).
        num_heads: Number of attention heads (default 8).
        num_layers: Number of Transformer encoder layers (default 4).
        feedforward_dim: Hidden width of each feed-forward block (default 512).
        out_dim: Output feature dimension (default 128).
        pooling: ``'mean'`` (average over tokens) or ``'first'`` (first token); any
            other value makes ``forward`` raise ``ValueError``.
        dropout: Dropout probability inside the Transformer layers (default 0.1).
    """

    def __init__(
        self,
        input_dim: int=384,
        embed_dim: int = 384,
        num_heads: int = 8,
        num_layers: int = 4,
        feedforward_dim: int = 512,
        out_dim: int = 128,
        pooling: str = 'mean',  # 'mean' or 'first'
        dropout: float = 0.1
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.pooling = pooling

        self.input_proj = nn.Linear(input_dim, embed_dim)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=feedforward_dim,
                dropout=dropout,
                batch_first=True,  # tensors are [batch, seq, feature]
            ),
            num_layers=num_layers,
        )
        self.output_proj = nn.Linear(embed_dim, out_dim)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights with zero bias for both projections; Xavier for every
        multi-dimensional Transformer parameter (1-D params keep PyTorch defaults)."""
        for projection in (self.input_proj, self.output_proj):
            nn.init.xavier_uniform_(projection.weight)
            nn.init.constant_(projection.bias, 0)

        matrices = (p for p in self.transformer_encoder.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[batch, seq_len, input_dim]`` to ``[batch, out_dim]``."""
        tokens = self.transformer_encoder(self.input_proj(x))  # [batch, seq_len, embed_dim]

        strategies = {
            'mean': lambda t: t.mean(dim=1),
            'first': lambda t: t[:, 0, :],
        }
        pool = strategies.get(self.pooling)
        if pool is None:
            raise ValueError(f"Unsupported pooling: {self.pooling}")

        return self.output_proj(pool(tokens))  # [batch, out_dim]

class DynamicClusterHead(nn.Module):
    def __init__(self, input_dim=256*384, hidden_dim=128, num_clusters=1000, momentum=0.99,
                 init_scale=0.05, center_init_scale=0.1, mid_dim=768):
        """
        Online clustering head: an MLP projector plus EMA-updated cluster centers.

        Args:
            input_dim: Dimension of the (flattened) input features.
            hidden_dim: Dimension of the projected space the centers live in.
            num_clusters: Number of cluster centers.
            momentum: EMA coefficient for center/count updates (closer to 1 = slower).
            init_scale: Multiplier applied to the projector's Linear weights after init.
            center_init_scale: Std of the random initial centers.
            mid_dim: Width of the projector's hidden layers.
        """
        super().__init__()
        self.num_clusters = num_clusters
        self.momentum = momentum

        # Projector: input_dim -> mid_dim -> mid_dim -> hidden_dim
        self.projector = nn.Sequential(
            nn.Linear(input_dim, mid_dim),
            nn.BatchNorm1d(mid_dim),
            nn.SiLU(inplace=True),
            nn.Linear(mid_dim, mid_dim),
            nn.BatchNorm1d(mid_dim),
            nn.SiLU(inplace=True),
            nn.Linear(mid_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, affine=False)
        )

        self.criterion = torch.nn.CrossEntropyLoss()

        self._init_projector_weights(init_scale)

        self.register_buffer('centers', torch.randn(num_clusters, hidden_dim) * center_init_scale)
        self.register_buffer('center_counts', torch.ones(num_clusters))

    def _init_projector_weights(self, init_scale):
        """Initialise the projector's Linear/BatchNorm layers (Linear weights scaled by ``init_scale``)."""
        # First Linear (index 0): Kaiming normal, then scaled
        nn.init.kaiming_normal_(self.projector[0].weight,
                                mode='fan_out',
                                nonlinearity='relu')
        self.projector[0].weight.data.mul_(init_scale)
        if self.projector[0].bias is not None:
            nn.init.constant_(self.projector[0].bias, 0.0)

        # First BatchNorm (index 1): gamma=1, beta=0
        if hasattr(self.projector[1], 'weight'):
            nn.init.constant_(self.projector[1].weight, 1.0)
        if hasattr(self.projector[1], 'bias'):
            nn.init.constant_(self.projector[1].bias, 0.0)

        # Second Linear (index 3): Kaiming normal, then scaled
        nn.init.kaiming_normal_(self.projector[3].weight,
                                mode='fan_out',
                                nonlinearity='relu')
        self.projector[3].weight.data.mul_(init_scale)
        if self.projector[3].bias is not None:
            nn.init.constant_(self.projector[3].bias, 0.0)

        # Second BatchNorm (index 4): gamma=1, beta=0
        if hasattr(self.projector[4], 'weight'):
            nn.init.constant_(self.projector[4].weight, 1.0)
        if hasattr(self.projector[4], 'bias'):
            nn.init.constant_(self.projector[4].bias, 0.0)

        # Third Linear (index 6): Xavier normal with ReLU gain, then scaled
        nn.init.xavier_normal_(self.projector[6].weight,
                               gain=nn.init.calculate_gain('relu'))
        self.projector[6].weight.data.mul_(init_scale)
        if self.projector[6].bias is not None:
            nn.init.constant_(self.projector[6].bias, 0.0)

        # Last BatchNorm (index 7) has affine=False: nothing to initialise.

    def forward(self, x, update_centers=True):
        """
        Project ``x``, assign each sample to its most similar center and compute the loss.

        Returns:
            (loss, assignments [bs], projected_norm [bs, hidden_dim]); centers are
            updated in place when in training mode and ``update_centers`` is True.
        """
        projected = self.projector(x)
        projected_norm = F.normalize(projected, p=2, dim=1)  # [bs, dim]
        centers_norm = F.normalize(self.centers, p=2, dim=1)  # [num, dim]
        similarity = torch.mm(projected_norm, centers_norm.t())  # [bs, num]
        assignments = torch.argmax(similarity, dim=1)  # [bs]
        loss = self.compute_cluster_loss(similarity, assignments)

        if self.training and update_centers:
            self.update_centers(projected_norm, assignments)

        return loss, assignments, projected_norm

    def compute_cluster_loss(self, similarity, assignments):
        """Negative mean cosine similarity between each sample and its assigned center."""
        # gather keeps the index on the same device as `similarity`.
        sim_to_center = similarity.gather(1, assignments.unsqueeze(1)).squeeze(1)
        return -torch.mean(sim_to_center)

    def update_centers(self, features, assignments):
        """Count-weighted EMA update of the centers hit by ``assignments``; all centers are re-normalised."""
        with torch.no_grad():
            onehot = F.one_hot(assignments, self.num_clusters).float()  # [bs, cluster]
            # Match the features' dtype so mm also works when the module runs in float64/half.
            center_sums = torch.mm(onehot.t().to(features.dtype), features)  # [cluster, dim]
            counts = onehot.sum(dim=0)  # [cluster]

            # Avoid zero counts in the EMA denominators.
            counts = torch.clamp(counts, min=1e-8)

            updated_counts = self.momentum * self.center_counts + (1 - self.momentum) * counts

            new_centers = (self.momentum * self.center_counts[:, None] * self.centers +
                           (1 - self.momentum) * center_sums) / updated_counts[:, None]
            # Only clusters that received at least one sample move.
            self.centers = torch.where(counts[:, None] > 0.5, new_centers, self.centers)

            self.centers = F.normalize(self.centers, p=2, dim=1)

            self.center_counts = updated_counts

    def reset_centers(self, scale=0.1):
        """Re-draw the centers from N(0, scale^2) and reset all counts to 1."""
        self.centers.normal_(0, scale)
        self.center_counts.fill_(1)

    @staticmethod
    def _offdiag_extremes(similarity):
        """Per-row max and min of a square similarity matrix, ignoring the diagonal."""
        self_mask = torch.eye(similarity.size(0), dtype=torch.bool, device=similarity.device)
        max_sim = similarity.masked_fill(self_mask, float('-inf')).max(dim=1).values
        min_sim = similarity.masked_fill(self_mask, float('inf')).min(dim=1).values
        return max_sim, min_sim

    def get_center_diversity(self):
        """
        Return center diversity as ``1 - mean off-diagonal cosine similarity``.

        Also prints, per center, the max and then the min cosine similarity to the
        other centers (self-similarity excluded).
        """
        centers_norm = F.normalize(self.centers, p=2, dim=1)
        similarity = torch.mm(centers_norm, centers_norm.t())

        max_sim, min_sim = self._offdiag_extremes(similarity)
        print(max_sim)
        print(min_sim)

        # Zero the diagonal (self-similarity) before averaging.
        mask = 1 - torch.eye(self.num_clusters, device=similarity.device)
        avg_similarity = (similarity * mask).sum() / (self.num_clusters * (self.num_clusters - 1))

        return 1 - avg_similarity.item()

    def get_momentum(self):
        return self.momentum

    def set_momentum(self, momentum):
        self.momentum = momentum

class ParallelClusterHead(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128, num_groups=256,
                 num_clusters=1000, momentum=0.99, init_scale=0.05):
        """
        Parallel multi-group online clustering head.

        A single projector (shared by all groups) maps each group's features to
        ``hidden_dim``; each group then has its own ``num_clusters`` EMA-updated centers.

        Args:
            input_dim: Feature dimension of each group's input.
            hidden_dim: Dimension of the projected space.
            num_groups: Number of groups (second dimension of the input).
            num_clusters: Number of centers per group.
            momentum: EMA coefficient for center/count updates.
            init_scale: Multiplier for the projector's Linear weights and the std of
                the random initial centers.
        """
        super().__init__()
        self.num_groups = num_groups
        self.num_clusters = num_clusters
        self.momentum = momentum

        # Projector shared across all groups.
        self.projector = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SiLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SiLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, affine=False)
        )

        self.criterion = torch.nn.CrossEntropyLoss()

        self._init_projector_weights(init_scale)

        # Centers: [groups, clusters, hidden_dim]
        self.register_buffer('centers',
                             torch.randn(num_groups, num_clusters, hidden_dim) * init_scale)

        # EMA sample counts per center: [groups, clusters]
        self.register_buffer('center_counts',
                             torch.ones(num_groups, num_clusters))

    def _init_projector_weights(self, init_scale):
        """Initialise the projector's Linear/BatchNorm layers (Linear weights scaled by ``init_scale``)."""
        # First Linear (index 0): Kaiming normal, then scaled
        nn.init.kaiming_normal_(self.projector[0].weight,
                                mode='fan_out',
                                nonlinearity='relu')
        self.projector[0].weight.data.mul_(init_scale)
        if self.projector[0].bias is not None:
            nn.init.constant_(self.projector[0].bias, 0.0)

        # First BatchNorm (index 1): gamma=1, beta=0
        if hasattr(self.projector[1], 'weight'):
            nn.init.constant_(self.projector[1].weight, 1.0)
        if hasattr(self.projector[1], 'bias'):
            nn.init.constant_(self.projector[1].bias, 0.0)

        # Second Linear (index 3): Kaiming normal, then scaled
        nn.init.kaiming_normal_(self.projector[3].weight,
                                mode='fan_out',
                                nonlinearity='relu')
        self.projector[3].weight.data.mul_(init_scale)
        if self.projector[3].bias is not None:
            nn.init.constant_(self.projector[3].bias, 0.0)

        # Second BatchNorm (index 4): gamma=1, beta=0
        if hasattr(self.projector[4], 'weight'):
            nn.init.constant_(self.projector[4].weight, 1.0)
        if hasattr(self.projector[4], 'bias'):
            nn.init.constant_(self.projector[4].bias, 0.0)

        # Third Linear (index 6): Xavier normal with ReLU gain, then scaled
        nn.init.xavier_normal_(self.projector[6].weight,
                               gain=nn.init.calculate_gain('relu'))
        self.projector[6].weight.data.mul_(init_scale)
        if self.projector[6].bias is not None:
            nn.init.constant_(self.projector[6].bias, 0.0)

        # Last BatchNorm (index 7) has affine=False: nothing to initialise.

    def forward(self, x, update_centers=True):
        """
        Args:
            x: Input features [batch_size, num_groups, input_dim].
        Returns:
            loss: Scalar clustering loss.
            assignments: Nearest-center index per sample and group [batch_size, num_groups].
            features_norm: L2-normalised projected features [batch_size, num_groups, hidden_dim].
        """
        batch_size = x.size(0)

        # [batch, groups, dim] -> [batch*groups, dim]
        flat_x = x.reshape(-1, x.size(-1))

        # [batch*groups, dim] -> [batch*groups, hidden_dim]
        projected = self.projector(flat_x)

        # Back to grouped layout: [batch, groups, hidden_dim]
        features = projected.reshape(batch_size, self.num_groups, -1)

        features_norm = F.normalize(features, p=2, dim=-1)

        # Centers are normalised after each update, but not at initialisation,
        # so normalise here as well.
        centers_norm = F.normalize(self.centers, p=2, dim=-1)

        similarity = torch.einsum('bgf,gcf->bgc', features_norm, centers_norm)  # [bs, group, cluster]

        assignments = torch.argmax(similarity, dim=-1)  # [batch_size, num_groups]

        loss = self.compute_cluster_loss(similarity, assignments)

        if self.training and update_centers:
            self.update_centers(features_norm, assignments)

        return loss, assignments, features_norm

    def compute_cluster_loss(self, similarity, assignments):
        """
        Cross-entropy between similarities (as logits) and the argmax assignments.

        Args:
            similarity: Similarity tensor [batch, groups, clusters].
            assignments: Assignments [batch, groups].
        """
        batch_size, num_groups, num_clusters = similarity.shape

        # [batch * groups, clusters]
        similarity_flat = similarity.contiguous().view(-1, num_clusters)

        # [batch * groups]
        assignments_flat = assignments.contiguous().view(-1).long()

        tau = 1.0

        dispersive_loss = self.criterion(similarity_flat / tau, assignments_flat)

        return dispersive_loss

    def update_centers(self, features, assignments):
        """
        Vectorised, count-weighted EMA update of every group's centers.

        Only centers that received at least one sample move; all centers are
        re-normalised afterwards.
        """
        with torch.no_grad():
            batch_size = features.size(0)

            # [batch, groups, dim] -> [batch*groups, dim]
            flat_features = features.reshape(-1, features.size(-1))
            flat_assignments = assignments.reshape(-1)

            # Group id of each flattened row: [0, 1, ..., G-1, 0, 1, ...]
            group_indices = torch.arange(self.num_groups, device=features.device)
            group_indices = group_indices.repeat(batch_size)

            # Index into the flattened [groups * clusters] center table.
            global_center_idx = group_indices * self.num_clusters + flat_assignments

            # Per-center feature sums; dtype follows the features because
            # index_add_ requires matching dtypes (e.g. after .double()/.half()).
            center_sums = torch.zeros(
                self.num_groups * self.num_clusters,
                features.size(-1),
                dtype=features.dtype,
                device=features.device
            )
            center_sums.index_add_(0, global_center_idx, flat_features)

            # Per-center sample counts.
            counts = torch.zeros(
                self.num_groups * self.num_clusters,
                device=features.device
            )
            counts.index_add_(
                0,
                global_center_idx,
                torch.ones_like(global_center_idx, dtype=torch.float)
            )

            # [groups, clusters, dim] and [groups, clusters]
            center_sums = center_sums.view(self.num_groups, self.num_clusters, -1)
            counts = counts.view(self.num_groups, self.num_clusters)

            # Avoid division by zero for empty clusters.
            counts = counts.clamp(min=1e-8)

            # Batch mean feature per center.
            new_centers = center_sums / counts.unsqueeze(-1)

            # EMA of the counts.
            updated_counts = (
                self.momentum * self.center_counts
                + (1 - self.momentum) * counts
            )

            # Only centers with at least one assigned sample are updated.
            update_mask = (counts > 0.5).unsqueeze(-1)

            updated_centers = (
                self.momentum * self.center_counts.unsqueeze(-1) * self.centers +
                (1 - self.momentum) * counts.unsqueeze(-1) * new_centers
            ) / updated_counts.unsqueeze(-1)

            self.centers = torch.where(
                update_mask,
                updated_centers,
                self.centers
            )

            self.center_counts = updated_counts

            self.centers = F.normalize(self.centers, p=2, dim=-1)

    def get_group_diversity(self, group_idx):
        """Return ``1 - mean off-diagonal cosine similarity`` of one group's centers."""
        centers = self.centers[group_idx]
        centers_norm = F.normalize(centers, p=2, dim=1)
        similarity = torch.mm(centers_norm, centers_norm.t())

        # Zero the diagonal (self-similarity).
        mask = 1 - torch.eye(self.num_clusters, device=similarity.device)
        avg_similarity = (similarity * mask).sum() / (self.num_clusters * (self.num_clusters - 1))

        return 1 - avg_similarity.item()


def warmup_centers(cluster_head, dataloader, num_batches=10, momentum=0.9):
    """Warm up a cluster head's centers by feeding it a few batches.

    The head is switched to eval mode (its projector then uses BatchNorm running
    statistics) and its momentum is temporarily replaced by ``momentum``. Both are
    restored on exit, even if an exception is raised.

    Args:
        cluster_head: Head exposing ``projector``, 2-D ``centers``, ``momentum``,
            ``update_centers`` and ``get_center_diversity`` (e.g. DynamicClusterHead).
        dataloader: Iterable yielding feature tensors, or lists/tuples whose first
            item is the feature tensor. Inputs with more than 2 dims are flattened
            per sample.
        num_batches: Maximum number of batches to consume (<= 0 consumes none).
        momentum: Momentum used for the center updates during warm-up.
    """
    import itertools

    original_mode = cluster_head.training
    original_momentum = cluster_head.momentum
    device = cluster_head.centers.device

    # Temporarily speed up the EMA updates.
    cluster_head.momentum = momentum
    cluster_head.eval()

    try:
        with torch.no_grad():
            print(f"开始聚类中心预热 ({num_batches} batches, momentum={momentum})...")
            batch_count = 0
            # islice stops after num_batches without pulling an extra batch.
            for data in itertools.islice(dataloader, max(num_batches, 0)):
                features = data[0] if isinstance(data, (list, tuple)) else data
                features = features.to(device)

                # Flatten images / sequences to [batch, features].
                if features.dim() > 2:
                    features = features.flatten(start_dim=1)

                projected = cluster_head.projector(features)
                projected_norm = F.normalize(projected, p=2, dim=1)

                centers_norm = F.normalize(cluster_head.centers, p=2, dim=1)
                similarity = torch.mm(projected_norm, centers_norm.t())
                assignments = torch.argmax(similarity, dim=1)

                cluster_head.update_centers(projected_norm, assignments)

                batch_count += 1

            # Report the number of batches actually processed (the loader may be shorter).
            print(f"预热完成! 更新了{batch_count}个批次")
            print(f"中心范数范围: {torch.norm(cluster_head.centers, dim=1).min().item():.4f} - "
                  f"{torch.norm(cluster_head.centers, dim=1).max().item():.4f}")
            print(f"中心多样性: {cluster_head.get_center_diversity():.4f}")
    finally:
        # Restore the original momentum and mode.
        cluster_head.momentum = original_momentum
        if original_mode:
            cluster_head.train()

def analyze_cluster_state(cluster_head, dataloader, num_batches=5):
    """Print assignment statistics of a cluster head over up to ``num_batches`` batches.

    Reports the number of analysed batches, min/max samples per center, the fraction
    of centers that received no sample, the mean similarity to the assigned center
    (``nan`` if no batch was analysed) and the head's center diversity. The head runs
    in eval mode; its original mode is restored on exit, even on error.

    Args:
        cluster_head: Head exposing ``projector``, 2-D ``centers``, ``num_clusters`` and
            ``get_center_diversity`` (e.g. DynamicClusterHead).
        dataloader: Iterable yielding feature tensors, or lists/tuples whose first item
            is the feature tensor. Inputs with more than 2 dims are flattened per sample.
        num_batches: Maximum number of batches to analyse (<= 0 analyses none).
    """
    import itertools

    original_mode = cluster_head.training
    device = cluster_head.centers.device
    cluster_head.eval()

    try:
        assignment_counts = torch.zeros(cluster_head.num_clusters, device=device)
        avg_sim_to_center = 0.0
        batch_count = 0

        with torch.no_grad():
            # Centers are not modified during the analysis, so normalise them once.
            centers_norm = F.normalize(cluster_head.centers, p=2, dim=1)
            for data in itertools.islice(dataloader, max(num_batches, 0)):
                features = data[0] if isinstance(data, (list, tuple)) else data
                features = features.to(device)

                if features.dim() > 2:
                    features = features.flatten(start_dim=1)

                projected = cluster_head.projector(features)
                projected_norm = F.normalize(projected, p=2, dim=1)
                similarity = torch.mm(projected_norm, centers_norm.t())
                assignments = torch.argmax(similarity, dim=1)

                # Per-center assignment histogram.
                assignment_counts += torch.bincount(assignments, minlength=cluster_head.num_clusters)

                # Mean similarity to the assigned (most similar) center.
                sim_to_center = similarity.gather(1, assignments.unsqueeze(1)).squeeze(1)
                avg_sim_to_center += sim_to_center.mean().item()

                batch_count += 1

        unused_centers = (assignment_counts == 0).sum().item()
        unused_ratio = unused_centers / cluster_head.num_clusters
        # Guard against an empty dataloader (previously a ZeroDivisionError).
        mean_similarity = avg_sim_to_center / batch_count if batch_count else float('nan')

        print("\n===== 聚类状态分析 =====")
        print(f"已分析批次: {batch_count}")
        print(f"样本分配情况: {assignment_counts.min().item()}-{assignment_counts.max().item()}")
        print(f"未使用中心比例: {unused_ratio:.2%} ({unused_centers}/{cluster_head.num_clusters})")
        print(f"平均相似度: {mean_similarity:.4f}")
        print(f"中心多样性: {cluster_head.get_center_diversity():.4f}")
    finally:
        if original_mode:
            cluster_head.train()

# Example usage
if __name__ == "__main__":
    # Demo: a queue holding at most 5 three-dimensional vectors.
    queue = MomentumQueue(5, 3)
    print("Initial queue:")
    print(queue.get_all())

    # push() requires one timestamp per element.
    # First push fits entirely in the empty queue (len 0 -> 3).
    new_elements = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
    queue.push(new_elements, [0, 1, 2])
    print("\nQueue after updating first 2 elements:")
    print(queue.get_all())

    # Second push fills the last 2 free slots and overwrites the oldest row.
    new_elements = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
    queue.push(new_elements, [3, 4, 5])
    print("\nQueue after updating first 3 elements:")
    print(queue.get_all())

    # Third push on the full queue overwrites rows starting at the ring pointer.
    new_elements = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
    queue.push(new_elements, [6, 7, 8])
    print("\nQueue after updating first 3 elements:")
    print(queue.get_all())